from dataclasses import asdict
from itertools import combinations
from typing import Callable

import jsonschema
import pytest

import tests.functional.services.policy_engine.utils.api as policy_engine_api
from tests.functional.services.policy_engine.conftest import is_legacy_provider
from tests.functional.services.policy_engine.utils.utils import (
    ImagesByVulnerabilityQuery,
    ImagesByVulnerabilityQueryOptions,
)

# Queries used to parametrize
# TestQueryByVulnerability.test_query_by_vulnerability_arg_combinations, which
# re-issues each query with combinations of its option fields. Each spec is
# (vulnerability id, expected affected image digests, positional arguments for
# ImagesByVulnerabilityQueryOptions).
arg_combination_tests = [
    ImagesByVulnerabilityQuery(
        vulnerability_id,
        affected_digests,
        ImagesByVulnerabilityQueryOptions(*option_args),
    )
    for vulnerability_id, affected_digests, option_args in (
        (
            "CVE-2011-3374",
            ["sha256:406413437f26223183d133ccc7186f24c827729e1b21adc7330dd43fcdc030b3"],
            ("Negligible", "debian:10", "libapt-pkg5.0", False),
        ),
        (
            "CVE-2020-8177",
            ["sha256:fe3ca35038008b0eac0fa4e686bd072c9430000ab7d7853001bde5f5b8ccf60c"],
            ("Medium", "rhel:7", "curl", False),
        ),
        (
            "GHSA-5mg8-w23w-74h3",
            [
                "sha256:80a31c3ce2e99c3691c27ac3b1753163214494e9b2ca07bfdccf29a5cca2bfbe",
                "sha256:406413437f26223183d133ccc7186f24c827729e1b21adc7330dd43fcdc030b3",
                "sha256:fe3ca35038008b0eac0fa4e686bd072c9430000ab7d7853001bde5f5b8ccf60c",
            ],
            ("Medium", "github:java", "guava", False),
        ),
    )
]


class TestQueryByVulnerability:
    """Functional tests for the policy engine's images-by-vulnerability query.

    Every test requests the ``setup_all_images`` fixture. It then checks that
    the query returns HTTP 200, that the body validates against
    ``query_by_vulnerability.schema.json``, and that the set of returned image
    digests equals the expected set.
    """

    @classmethod
    def _test_query_by_vulnerability(
        cls,
        query: ImagesByVulnerabilityQuery,
        schema_validator: Callable[[str], jsonschema.Draft7Validator],
    ) -> None:
        """Run ``query`` against the API and assert on the response.

        :param query: vulnerability id, expected affected image digests, and
            filter options. The options are expanded (via ``asdict``) into
            keyword arguments of ``get_images_by_vulnerability``.
        :param schema_validator: factory that returns a JSON-schema validator
            for the named schema file.
        """
        get_image_resp = policy_engine_api.users.get_images_by_vulnerability(
            vulnerability_id=query.vulnerability_id, **asdict(query.query_metadata)
        )
        assert get_image_resp.code == 200, (
            f"unexpected status {get_image_resp.code} for {query}: "
            f"{get_image_resp.body}"
        )

        # Validate the response body against the published schema. On failure,
        # report every validation error, not just the first one.
        query_by_vuln_schema_validator = schema_validator(
            "query_by_vulnerability.schema.json"
        )
        is_valid: bool = query_by_vuln_schema_validator.is_valid(get_image_resp.body)
        assert is_valid, "\n".join(
            str(error)
            for error in query_by_vuln_schema_validator.iter_errors(
                get_image_resp.body
            )
        )

        # Compare as sets: the order of returned images is not asserted on.
        returned_digests = {
            image["image"]["imageDigest"]
            for image in get_image_resp.body["vulnerable_images"]
        }
        assert returned_digests == set(
            query.affected_images
        ), f"unexpected vulnerable images for {query}"

    @pytest.mark.parametrize(
        "query",
        [
            ImagesByVulnerabilityQuery(
                "CVE-2013-2512",
                [
                    "sha256:80a31c3ce2e99c3691c27ac3b1753163214494e9b2ca07bfdccf29a5cca2bfbe",
                    "sha256:406413437f26223183d133ccc7186f24c827729e1b21adc7330dd43fcdc030b3",
                    "sha256:fe3ca35038008b0eac0fa4e686bd072c9430000ab7d7853001bde5f5b8ccf60c",
                ],
            ),
            ImagesByVulnerabilityQuery(
                "GHSA-h6q6-9hqw-rwfv",
                [
                    "sha256:80a31c3ce2e99c3691c27ac3b1753163214494e9b2ca07bfdccf29a5cca2bfbe",
                    "sha256:406413437f26223183d133ccc7186f24c827729e1b21adc7330dd43fcdc030b3",
                    "sha256:fe3ca35038008b0eac0fa4e686bd072c9430000ab7d7853001bde5f5b8ccf60c",
                ],
            ),
            # The vulnerability id queried depends on which provider is under
            # test (see is_legacy_provider()).
            ImagesByVulnerabilityQuery(
                "CVE-2021-21330" if is_legacy_provider() else "GHSA-v6wp-4m6f-gcjg",
                [
                    "sha256:80a31c3ce2e99c3691c27ac3b1753163214494e9b2ca07bfdccf29a5cca2bfbe",
                    "sha256:406413437f26223183d133ccc7186f24c827729e1b21adc7330dd43fcdc030b3",
                    "sha256:fe3ca35038008b0eac0fa4e686bd072c9430000ab7d7853001bde5f5b8ccf60c",
                ],
            ),
            ImagesByVulnerabilityQuery(
                "CVE-2020-8908" if is_legacy_provider() else "GHSA-5mg8-w23w-74h3",
                [
                    "sha256:80a31c3ce2e99c3691c27ac3b1753163214494e9b2ca07bfdccf29a5cca2bfbe",
                    "sha256:406413437f26223183d133ccc7186f24c827729e1b21adc7330dd43fcdc030b3",
                    "sha256:fe3ca35038008b0eac0fa4e686bd072c9430000ab7d7853001bde5f5b8ccf60c",
                ],
            ),
            ImagesByVulnerabilityQuery(
                "CVE-2020-8177",
                [
                    "sha256:fe3ca35038008b0eac0fa4e686bd072c9430000ab7d7853001bde5f5b8ccf60c"
                ],
            ),
            # Each of the next three filters excludes the only image expected
            # for CVE-2020-8177 above, so the result must be empty.
            ImagesByVulnerabilityQuery(
                "CVE-2020-8177",
                [],
                ImagesByVulnerabilityQueryOptions(severity="High"),
            ),
            ImagesByVulnerabilityQuery(
                "CVE-2020-8177",
                [],
                ImagesByVulnerabilityQueryOptions(namespace="debian:10"),
            ),
            ImagesByVulnerabilityQuery(
                "CVE-2020-8177",
                [],
                ImagesByVulnerabilityQueryOptions(affected_package="gnupg2-2.0.22"),
            ),
            # A vulnerability id expected to match no image.
            ImagesByVulnerabilityQuery("CVE-2020-0000", []),
        ],
    )
    def test_query_by_vulnerability(self, schema_validator, setup_all_images, query):
        """Query by vulnerability id (optionally filtered) and check the images."""
        self._test_query_by_vulnerability(query, schema_validator)

    @pytest.mark.parametrize("query", arg_combination_tests)
    def test_query_by_vulnerability_arg_combinations(
        self,
        schema_validator,
        setup_all_images,
        query,
    ):
        """Re-run ``query`` with every non-empty subset of its option fields.

        For n option fields, this issues one request per k-combination for
        every k in 1..n, which is 2**n - 1 requests in total. The last request
        uses the full original option set. Every subset must still return
        exactly ``query.affected_images``.
        """
        query_fields = list(asdict(query.query_metadata).items())
        # range() excludes its stop value, so "+ 1" is needed to also cover
        # the combination that uses all n fields.
        for combination_length in range(1, len(query_fields) + 1):
            for query_combination in combinations(query_fields, combination_length):
                new_query = ImagesByVulnerabilityQuery(
                    query.vulnerability_id,
                    query.affected_images,
                    ImagesByVulnerabilityQueryOptions(**dict(query_combination)),
                )
                self._test_query_by_vulnerability(new_query, schema_validator)
